{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "\n",
    "import nemo\n",
    "from nemo.utils.lr_policies import WarmupAnnealing\n",
    "\n",
    "import nemo.collections.nlp as nemo_nlp\n",
    "from nemo.collections.nlp.data import NemoBertTokenizer, SentencePieceTokenizer\n",
    "from nemo.collections.nlp.callbacks.token_classification_callback import \\\n",
    "    eval_iter_callback, eval_epochs_done_callback\n",
    "from nemo.backends.pytorch.common.losses import CrossEntropyLossNM\n",
    "from nemo.collections.nlp.nm.trainables import TokenClassifier\n",
    "from nemo import logging"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the data from [here](https://github.com/kyzhouhzau/BERT-NER/tree/master/data) and preprocess it with [this script](https://github.com/NVIDIA/NeMo/blob/master/examples/nlp/token_classification/import_from_iob_format.py). The cells below expect to find `text_train.txt`, `labels_train.txt`, `text_dev.txt` and `labels_dev.txt` in `DATA_DIR`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Paths (placeholders: replace both before running) ---\n",
    "DATA_DIR = \"PATH TO WHERE THE DATA IS\"\n",
    "WORK_DIR = \"PATH_TO_WHERE_TO_STORE_CHECKPOINTS_AND_LOGS\"\n",
    "\n",
    "# --- Model ---\n",
    "PRETRAINED_BERT_MODEL = \"bert-base-cased\"\n",
    "MAX_SEQ_LENGTH = 128\n",
    "CLASSIFICATION_DROPOUT = 0.1\n",
    "\n",
    "# --- Batching ---\n",
    "BATCH_SIZE = 32\n",
    "BATCHES_PER_STEP = 1\n",
    "\n",
    "# --- Optimization ---\n",
    "OPTIMIZER = \"adam\"\n",
    "NUM_EPOCHS = 3\n",
    "LEARNING_RATE = 5e-5\n",
    "LR_WARMUP_PROPORTION = 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the neural module factory (PyTorch backend) used later for\n",
    "# checkpointing and training.\n",
    "neural_factory = nemo.core.NeuralModuleFactory(\n",
    "    backend=nemo.core.Backend.PyTorch,\n",
    "\n",
    "    # Single-GPU placement. For multi-GPU training switch this to\n",
    "    # nemo.core.DeviceType.AllGpu ...\n",
    "    placement=nemo.core.DeviceType.GPU,\n",
    "\n",
    "    # ... and supply the local rank instead of None, e.g. parsed with argparse\n",
    "    # (see examples/nlp/token_classification.py for an example).\n",
    "    local_rank=None,\n",
    "\n",
    "    # \"O0\" is used here. For mixed precision pass a higher level such as\n",
    "    # \"O1\" or \"O2\" (mxprO1 / mxprO2).\n",
    "    # See https://nvidia.github.io/apex/amp.html#opt-levels for more details.\n",
    "    optimization_level=\"O0\",\n",
    "\n",
    "    # Directory where results are stored\n",
    "    log_dir=WORK_DIR,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard pretrained BERT setup: the tokenizer and the encoder are both created\n",
    "# from PRETRAINED_BERT_MODEL. To see the full list of BERT/ALBERT/RoBERTa model\n",
    "# names, call nemo_nlp.nm.trainables.get_bert_models_list()\n",
    "\n",
    "tokenizer = NemoBertTokenizer(\n",
    "    pretrained_model=PRETRAINED_BERT_MODEL,\n",
    ")\n",
    "bert_model = nemo_nlp.nm.trainables.get_huggingface_model(\n",
    "    pretrained_model_name=PRETRAINED_BERT_MODEL,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training graph: data layer -> BERT encoder -> token classifier -> loss\n",
    "train_data_layer = nemo_nlp.nm.data_layers.BertTokenClassificationDataLayer(\n",
    "    batch_size=BATCH_SIZE,\n",
    "    max_seq_length=MAX_SEQ_LENGTH,\n",
    "    tokenizer=tokenizer,\n",
    "    label_file=os.path.join(DATA_DIR, 'labels_train.txt'),\n",
    "    text_file=os.path.join(DATA_DIR, 'text_train.txt'),\n",
    ")\n",
    "\n",
    "# The label mapping held by the training dataset is reused by the evaluation\n",
    "# data layer and the evaluation callback.\n",
    "label_ids = train_data_layer.dataset.label_ids\n",
    "num_classes = len(label_ids)\n",
    "hidden_size = bert_model.hidden_size\n",
    "\n",
    "ner_classifier = TokenClassifier(\n",
    "    hidden_size=hidden_size,\n",
    "    num_classes=num_classes,\n",
    "    dropout=CLASSIFICATION_DROPOUT,\n",
    ")\n",
    "ner_loss = CrossEntropyLossNM(logits_ndim=3)\n",
    "\n",
    "# The fifth output of the data layer is not used for training.\n",
    "(input_ids, input_type_ids, input_mask,\n",
    " loss_mask, _, labels) = train_data_layer()\n",
    "\n",
    "hidden_states = bert_model(\n",
    "    input_ids=input_ids,\n",
    "    token_type_ids=input_type_ids,\n",
    "    attention_mask=input_mask,\n",
    ")\n",
    "logits = ner_classifier(hidden_states=hidden_states)\n",
    "loss = ner_loss(logits=logits, labels=labels, loss_mask=loss_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation graph: the same encoder and classifier as in training, fed from the dev files.\n",
    "eval_data_layer = nemo_nlp.nm.data_layers.BertTokenClassificationDataLayer(\n",
    "        tokenizer=tokenizer,\n",
    "        text_file=os.path.join(DATA_DIR, 'text_dev.txt'),\n",
    "        label_file=os.path.join(DATA_DIR, 'labels_dev.txt'),\n",
    "        max_seq_length=MAX_SEQ_LENGTH,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        # Reuse the label mapping taken from the training data layer\n",
    "        label_ids=label_ids)\n",
    "\n",
    "# The fourth output of the data layer is not used for evaluation.\n",
    "(eval_input_ids, eval_input_type_ids, eval_input_mask,\n",
    " _, eval_subtokens_mask, eval_labels) = eval_data_layer()\n",
    "\n",
    "# Use an eval_-prefixed name so the training `hidden_states` tensor is not\n",
    "# rebound by this cell.\n",
    "eval_hidden_states = bert_model(\n",
    "    input_ids=eval_input_ids,\n",
    "    token_type_ids=eval_input_type_ids,\n",
    "    attention_mask=eval_input_mask)\n",
    "\n",
    "eval_logits = ner_classifier(hidden_states=eval_hidden_states)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Callback to log the training loss\n",
    "callback_train = nemo.core.SimpleLossLoggerCallback(\n",
    "    tensors=[loss],\n",
    "    print_func=lambda x: logging.info(\"Loss: {:.3f}\".format(x[0].item())))\n",
    "\n",
    "train_data_size = len(train_data_layer)\n",
    "\n",
    "# Number of full training steps per epoch. If you're training on multiple GPUs,\n",
    "# this should be train_data_size // (batch_size * batches_per_step * num_gpus).\n",
    "# Clamped to at least 1 so that a dataset smaller than one step's worth of\n",
    "# samples does not yield eval_step=0 and a zero-length LR schedule.\n",
    "steps_per_epoch = max(1, train_data_size // (BATCHES_PER_STEP * BATCH_SIZE))\n",
    "\n",
    "# Callback to evaluate the model; eval_step=steps_per_epoch is intended to run\n",
    "# the evaluation roughly once per epoch.\n",
    "callback_eval = nemo.core.EvaluatorCallback(\n",
    "    eval_tensors=[eval_logits, eval_labels, eval_subtokens_mask],\n",
    "    user_iter_callback=eval_iter_callback,\n",
    "    user_epochs_done_callback=lambda x: eval_epochs_done_callback(x, label_ids),\n",
    "    eval_step=steps_per_epoch)\n",
    "\n",
    "# Callback to store checkpoints\n",
    "# Checkpoints will be stored in checkpoints folder inside WORK_DIR\n",
    "ckpt_callback = nemo.core.CheckpointCallback(\n",
    "    folder=neural_factory.checkpoint_dir,\n",
    "    epoch_freq=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learning-rate policy sized to NUM_EPOCHS * steps_per_epoch steps, with\n",
    "# LR_WARMUP_PROPORTION passed as the warmup ratio.\n",
    "lr_policy = WarmupAnnealing(\n",
    "    NUM_EPOCHS * steps_per_epoch,\n",
    "    warmup_ratio=LR_WARMUP_PROPORTION,\n",
    ")\n",
    "\n",
    "# Start training: optimize the loss with the logging, evaluation and\n",
    "# checkpoint callbacks attached.\n",
    "neural_factory.train(\n",
    "    tensors_to_optimize=[loss],\n",
    "    optimizer=OPTIMIZER,\n",
    "    optimization_params={\"num_epochs\": NUM_EPOCHS, \"lr\": LEARNING_RATE},\n",
    "    lr_policy=lr_policy,\n",
    "    batches_per_step=BATCHES_PER_STEP,\n",
    "    callbacks=[callback_train, callback_eval, ckpt_callback],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
